{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The FineTune/Warm Up option\n",
    "\n",
    "Let's place ourselves in two possible scenarios. \n",
    "\n",
    "1. Let's assume we have run a model and we want to just transfer the learnings (you know...transfer-learning) to another dataset, or simply we have received new data and we do not want to start the training of each component from scratch. Simply, we want to load the pre-trained weights and fine-tune.\n",
    "\n",
    "2. We just want to \"warm up\" individual model components individually before the joined training begins.  \n",
    "\n",
    "This can be done with the `finetune` set of parameters. There are 3 fine-tuning routines:\n",
    "\n",
    "1. Fine-tune all trainable layers at once with a triangular one-cycle learning rate (referred as slanted triangular learning rates in Howard & Ruder 2018)\n",
    "2. Gradual fine-tuning inspired by the work of [Felbo et al., 2017](https://arxiv.org/abs/1708.00524)\n",
    "3. Gradual fine-tuning based on the work of [Howard & Ruder 2018](https://arxiv.org/abs/1801.06146)\n",
    "\n",
    "Currently fine-tunning is only supported without a fully connected head, i.e. if `deephead=None`. In addition, `Felbo` and `Howard` routines only applied, of course, to the `deeptabular`, `deeptext` and `deepimage` models. The `wide` component can also be fine-tuned, but only in an \"all at once\" mode."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fine-tune or warm-up all at once\n",
    "\n",
    "Here, the model components will be trained for `finetune_epochs` using a triangular one-cycle learning rate (slanted triangular learning rate) ranging from `finetune_max_lr/10` to `finetune_max_lr` (default is 0.01). 10% of the training steps are used to increase the learning rate which then decreases for the remaining 90%. \n",
    "\n",
    "Here all trainable layers are fine-tuned.\n",
    "\n",
    "Let's have a look to one example. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/javierrodriguezzaurin/.pyenv/versions/3.10.13/envs/widedeep310/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "from pytorch_widedeep import Trainer\n",
    "from pytorch_widedeep.preprocessing import WidePreprocessor, TabPreprocessor\n",
    "from pytorch_widedeep.models import Wide, TabMlp, TabResnet, WideDeep\n",
    "from pytorch_widedeep.metrics import Accuracy\n",
    "from pytorch_widedeep.datasets import load_adult"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>educational_num</th>\n",
       "      <th>marital_status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>gender</th>\n",
       "      <th>capital_gain</th>\n",
       "      <th>capital_loss</th>\n",
       "      <th>hours_per_week</th>\n",
       "      <th>native_country</th>\n",
       "      <th>income_label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>25</td>\n",
       "      <td>Private</td>\n",
       "      <td>226802</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>38</td>\n",
       "      <td>Private</td>\n",
       "      <td>89814</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Farming-fishing</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>50</td>\n",
       "      <td>United-States</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>28</td>\n",
       "      <td>Local-gov</td>\n",
       "      <td>336951</td>\n",
       "      <td>Assoc-acdm</td>\n",
       "      <td>12</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Protective-serv</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>44</td>\n",
       "      <td>Private</td>\n",
       "      <td>160323</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>10</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Husband</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>7688</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>18</td>\n",
       "      <td>?</td>\n",
       "      <td>103497</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>10</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>?</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>White</td>\n",
       "      <td>Female</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>30</td>\n",
       "      <td>United-States</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   age  workclass  fnlwgt     education  educational_num      marital_status  \\\n",
       "0   25    Private  226802          11th                7       Never-married   \n",
       "1   38    Private   89814       HS-grad                9  Married-civ-spouse   \n",
       "2   28  Local-gov  336951    Assoc-acdm               12  Married-civ-spouse   \n",
       "3   44    Private  160323  Some-college               10  Married-civ-spouse   \n",
       "4   18          ?  103497  Some-college               10       Never-married   \n",
       "\n",
       "          occupation relationship   race  gender  capital_gain  capital_loss  \\\n",
       "0  Machine-op-inspct    Own-child  Black    Male             0             0   \n",
       "1    Farming-fishing      Husband  White    Male             0             0   \n",
       "2    Protective-serv      Husband  White    Male             0             0   \n",
       "3  Machine-op-inspct      Husband  Black    Male          7688             0   \n",
       "4                  ?    Own-child  White  Female             0             0   \n",
       "\n",
       "   hours_per_week native_country  income_label  \n",
       "0              40  United-States             0  \n",
       "1              50  United-States             0  \n",
       "2              40  United-States             1  \n",
       "3              40  United-States             1  \n",
       "4              30  United-States             0  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = load_adult(as_frame=True)\n",
    "# For convenience, we'll replace '-' with '_'\n",
    "df.columns = [c.replace(\"-\", \"_\") for c in df.columns]\n",
    "# binary target\n",
    "df[\"income_label\"] = (df[\"income\"].apply(lambda x: \">50K\" in x)).astype(int)\n",
    "df.drop(\"income\", axis=1, inplace=True)\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define wide, crossed and deep tabular columns\n",
    "wide_cols = [\n",
    "    \"workclass\",\n",
    "    \"education\",\n",
    "    \"marital_status\",\n",
    "    \"occupation\",\n",
    "    \"relationship\",\n",
    "    \"race\",\n",
    "    \"gender\",\n",
    "    \"native_country\",\n",
    "]\n",
    "crossed_cols = [(\"education\", \"occupation\"), (\"native_country\", \"occupation\")]\n",
    "cat_embed_cols = [\n",
    "    \"workclass\",\n",
    "    \"education\",\n",
    "    \"marital_status\",\n",
    "    \"occupation\",\n",
    "    \"relationship\",\n",
    "    \"race\",\n",
    "    \"gender\",\n",
    "    \"capital_gain\",\n",
    "    \"capital_loss\",\n",
    "    \"native_country\",\n",
    "]\n",
    "continuous_cols = [\"age\", \"hours_per_week\"]\n",
    "target_col = \"income_label\"\n",
    "target = df[target_col].values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/javierrodriguezzaurin/Projects/pytorch-widedeep/pytorch_widedeep/preprocessing/tab_preprocessor.py:358: UserWarning: Continuous columns will not be normalised\n",
      "  warnings.warn(\"Continuous columns will not be normalised\")\n"
     ]
    }
   ],
   "source": [
    "# TARGET\n",
    "target = df[target_col].values\n",
    "\n",
    "# WIDE\n",
    "wide_preprocessor = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)\n",
    "X_wide = wide_preprocessor.fit_transform(df)\n",
    "\n",
    "# DEEP\n",
    "tab_preprocessor = TabPreprocessor(\n",
    "    cat_embed_cols=cat_embed_cols, continuous_cols=continuous_cols\n",
    ")\n",
    "X_tab = tab_preprocessor.fit_transform(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "wide = Wide(input_dim=np.unique(X_wide).shape[0], pred_dim=1)\n",
    "tab_mlp = TabMlp(\n",
    "    column_idx=tab_preprocessor.column_idx,\n",
    "    cat_embed_input=tab_preprocessor.cat_embed_input,\n",
    "    cat_embed_dropout=0.1,\n",
    "    continuous_cols=continuous_cols,\n",
    "    embed_continuous_method=\"standard\",\n",
    "    cont_embed_dim=8,\n",
    "    mlp_hidden_dims=[64, 32],\n",
    "    mlp_dropout=0.2,\n",
    "    mlp_activation=\"leaky_relu\",\n",
    ")\n",
    "model = WideDeep(wide=wide, deeptabular=tab_mlp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Trainer(\n",
    "    model,\n",
    "    objective=\"binary\",\n",
    "    optimizers=torch.optim.Adam(model.parameters(), lr=0.01),\n",
    "    metrics=[Accuracy],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████████████████████████████████████████████████████| 153/153 [00:02<00:00, 74.26it/s, loss=0.399, metrics={'acc': 0.8163}]\n",
      "valid: 100%|██████████████████████████████████████████████████████████████| 39/39 [00:00<00:00, 91.03it/s, loss=0.296, metrics={'acc': 0.8677}]\n",
      "epoch 2: 100%|████████████████████████████████████████████████████████████| 153/153 [00:01<00:00, 81.31it/s, loss=0.3, metrics={'acc': 0.8614}]\n",
      "valid: 100%|█████████████████████████████████████████████████████████████| 39/39 [00:00<00:00, 106.45it/s, loss=0.285, metrics={'acc': 0.8721}]\n"
     ]
    }
   ],
   "source": [
    "trainer.fit(\n",
    "    X_wide=X_wide, X_tab=X_tab, target=target, n_epochs=2, val_split=0.2, batch_size=256\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.save(path=\"models_dir/\", save_state_dict=True, model_filename=\"model_1.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now time goes by...and we want to fine-tune the model to another, new dataset (for example, a dataset that is identical to the one you used to train the previous model but for another country). \n",
    "\n",
    "Here I will use the same dataset just for illustration purposes, but the flow would be identical to that new dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "wide_1 = Wide(input_dim=np.unique(X_wide).shape[0], pred_dim=1)\n",
    "tab_mlp_1 = TabMlp(\n",
    "    column_idx=tab_preprocessor.column_idx,\n",
    "    cat_embed_input=tab_preprocessor.cat_embed_input,\n",
    "    cat_embed_dropout=0.1,\n",
    "    continuous_cols=continuous_cols,\n",
    "    embed_continuous_method=\"standard\",\n",
    "    cont_embed_dim=8,\n",
    "    mlp_hidden_dims=[64, 32],\n",
    "    mlp_dropout=0.2,\n",
    "    mlp_activation=\"leaky_relu\",\n",
    ")\n",
    "model_1 = WideDeep(wide=wide, deeptabular=tab_mlp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_1.load_state_dict(torch.load(\"models_dir/model_1.pt\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer_1 = Trainer(model_1, objective=\"binary\", metrics=[Accuracy])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training wide for 2 epochs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|███████████████████████████████████████████████████████████| 191/191 [00:01<00:00, 97.37it/s, loss=0.39, metrics={'acc': 0.8152}]\n",
      "epoch 2: 100%|██████████████████████████████████████████████████████████| 191/191 [00:01<00:00, 104.04it/s, loss=0.359, metrics={'acc': 0.824}]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training deeptabular for 2 epochs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████████████████████████████████████████████████████| 191/191 [00:02<00:00, 83.83it/s, loss=0.297, metrics={'acc': 0.8365}]\n",
      "epoch 2: 100%|██████████████████████████████████████████████████████████| 191/191 [00:02<00:00, 82.78it/s, loss=0.283, metrics={'acc': 0.8445}]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fine-tuning (or warmup) of individual components completed. Training the whole model for 2 epochs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████████████████████████████████████████████████████| 191/191 [00:02<00:00, 72.84it/s, loss=0.281, metrics={'acc': 0.8716}]\n",
      "epoch 2: 100%|██████████████████████████████████████████████████████████| 191/191 [00:02<00:00, 77.46it/s, loss=0.273, metrics={'acc': 0.8744}]\n"
     ]
    }
   ],
   "source": [
    "trainer_1.fit(\n",
    "    X_wide=X_wide,\n",
    "    X_tab=X_tab,\n",
    "    target=target,\n",
    "    n_epochs=2,\n",
    "    batch_size=256,\n",
    "    finetune=True,\n",
    "    finetune_epochs=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that, as I describe above, in scenario 2, we can just use this to warm up models before they joined training begins:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "wide = Wide(input_dim=np.unique(X_wide).shape[0], pred_dim=1)\n",
    "tab_mlp = TabMlp(\n",
    "    column_idx=tab_preprocessor.column_idx,\n",
    "    cat_embed_input=tab_preprocessor.cat_embed_input,\n",
    "    cat_embed_dropout=0.1,\n",
    "    continuous_cols=continuous_cols,\n",
    "    embed_continuous_method=\"standard\",\n",
    "    cont_embed_dim=8,\n",
    "    mlp_hidden_dims=[64, 32],\n",
    "    mlp_dropout=0.2,\n",
    "    mlp_activation=\"leaky_relu\",\n",
    ")\n",
    "model = WideDeep(wide=wide, deeptabular=tab_mlp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer_2 = Trainer(model, objective=\"binary\", metrics=[Accuracy])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training wide for 2 epochs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████████████████████████████████████████████████████| 172/172 [00:01<00:00, 102.49it/s, loss=0.52, metrics={'acc': 0.7519}]\n",
      "epoch 2: 100%|██████████████████████████████████████████████████████████| 172/172 [00:01<00:00, 98.15it/s, loss=0.381, metrics={'acc': 0.7891}]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training deeptabular for 2 epochs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████████████████████████████████████████████████████| 172/172 [00:02<00:00, 82.97it/s, loss=0.356, metrics={'acc': 0.8043}]\n",
      "epoch 2: 100%|██████████████████████████████████████████████████████████| 172/172 [00:02<00:00, 80.27it/s, loss=0.295, metrics={'acc': 0.8195}]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fine-tuning (or warmup) of individual components completed. Training the whole model for 2 epochs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████████████████████████████████████████████████████| 172/172 [00:02<00:00, 77.27it/s, loss=0.291, metrics={'acc': 0.8667}]\n",
      "valid: 100%|██████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 89.57it/s, loss=0.289, metrics={'acc': 0.8665}]\n",
      "epoch 2: 100%|██████████████████████████████████████████████████████████| 172/172 [00:02<00:00, 72.69it/s, loss=0.283, metrics={'acc': 0.8693}]\n",
      "valid: 100%|███████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 91.81it/s, loss=0.284, metrics={'acc': 0.869}]\n"
     ]
    }
   ],
   "source": [
    "trainer_2.fit(\n",
    "    X_wide=X_wide,\n",
    "    X_tab=X_tab,\n",
    "    target=target,\n",
    "    val_split=0.1,\n",
    "    warmup=True,\n",
    "    warmup_epochs=2,\n",
    "    n_epochs=2,\n",
    "    batch_size=256,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fine-tune Gradually: The \"felbo\"  and the \"howard\" routines\n",
    "\n",
    "The Felbo routine can be illustrated as follows:\n",
    "\n",
    "<p align=\"center\">\n",
    "  <img width=\"600\" src=\"figures/felbo_routine.png\">\n",
    "</p>\n",
    "\n",
    "**Figure 1.** The figure can be described as follows: fine-tune (or train) the last layer for one epoch using a one cycle triangular learning rate. Then fine-tune the next deeper layer for one epoch, with a learning rate that is a factor of 2.5 lower than the previous learning rate (the 2.5 factor is fixed) while freezing the already warmed up layer(s). Repeat untill all individual layers are warmed. Then warm one last epoch with all warmed layers trainable. The vanishing color gradient in the figure attempts to illustrate the decreasing learning rate. \n",
    "\n",
    "Note that this is not identical to the Fine-Tunning routine described in Felbo et al, 2017, this is why I used the word 'inspired'.\n",
    "\n",
    "The Howard routine can be illustrated as follows:\n",
    "\n",
    "<p align=\"center\">\n",
    "  <img width=\"600\" src=\"figures/howard_routine.png\">\n",
    "</p>\n",
    "\n",
    "**Figure 2.** The figure can be described as follows: fine-tune (or train) the last layer for one epoch using a one cycle triangular learning rate. Then fine-tune the next deeper layer for one epoch, with a learning rate that is a factor of 2.5 lower than the previous learning rate (the 2.5 factor is fixed) while keeping the already warmed up layer(s) trainable. Repeat. The vanishing color gradient in the figure attempts to illustrate the decreasing learning rate. \n",
    "\n",
    "Note that I write \"*fine-tune (or train) the last layer for one epoch [...]*\". However, in practice the user will have to specify the order of the layers to be fine-tuned. This is another reason why I wrote that the fine-tune routines I have implemented are **inspired** by the work of Felbo and Howard and not identical to their implemenations.\n",
    "\n",
    "The `felbo` and `howard` routines can be accessed with via the `fine-tune` parameters."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We need to explicitly indicate \n",
    "\n",
    "1. That we want fine-tune\n",
    "\n",
    "2. The components that we want to individually fine-tune \n",
    "\n",
    "3. In case of gradual fine-tuning, the routine (\"felbo\" or \"howard\")\n",
    "\n",
    "4. The layers we want to fine-tune. \n",
    "\n",
    "For example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "wide = Wide(input_dim=np.unique(X_wide).shape[0], pred_dim=1)\n",
    "tab_resnet = TabResnet(\n",
    "    column_idx=tab_preprocessor.column_idx,\n",
    "    cat_embed_input=tab_preprocessor.cat_embed_input,\n",
    "    cat_embed_dropout=0.1,\n",
    "    continuous_cols=continuous_cols,\n",
    "    blocks_dims=[200, 200, 200],\n",
    ")\n",
    "model = WideDeep(wide=wide, deeptabular=tab_resnet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "WideDeep(\n",
       "  (wide): Wide(\n",
       "    (wide_linear): Embedding(809, 1, padding_idx=0)\n",
       "  )\n",
       "  (deeptabular): Sequential(\n",
       "    (0): TabResnet(\n",
       "      (cat_embed): DiffSizeCatEmbeddings(\n",
       "        (embed_layers): ModuleDict(\n",
       "          (emb_layer_workclass): Embedding(10, 5, padding_idx=0)\n",
       "          (emb_layer_education): Embedding(17, 8, padding_idx=0)\n",
       "          (emb_layer_marital_status): Embedding(8, 5, padding_idx=0)\n",
       "          (emb_layer_occupation): Embedding(16, 7, padding_idx=0)\n",
       "          (emb_layer_relationship): Embedding(7, 4, padding_idx=0)\n",
       "          (emb_layer_race): Embedding(6, 4, padding_idx=0)\n",
       "          (emb_layer_gender): Embedding(3, 2, padding_idx=0)\n",
       "          (emb_layer_capital_gain): Embedding(124, 24, padding_idx=0)\n",
       "          (emb_layer_capital_loss): Embedding(100, 21, padding_idx=0)\n",
       "          (emb_layer_native_country): Embedding(43, 13, padding_idx=0)\n",
       "        )\n",
       "        (embedding_dropout): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "      (cont_norm): Identity()\n",
       "      (encoder): DenseResnet(\n",
       "        (dense_resnet): Sequential(\n",
       "          (lin_inp): Linear(in_features=95, out_features=200, bias=False)\n",
       "          (bn_inp): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (block_0): BasicBlock(\n",
       "            (lin1): Linear(in_features=200, out_features=200, bias=False)\n",
       "            (bn1): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (leaky_relu): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "            (dp): Dropout(p=0.1, inplace=False)\n",
       "            (lin2): Linear(in_features=200, out_features=200, bias=False)\n",
       "            (bn2): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "          (block_1): BasicBlock(\n",
       "            (lin1): Linear(in_features=200, out_features=200, bias=False)\n",
       "            (bn1): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (leaky_relu): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "            (dp): Dropout(p=0.1, inplace=False)\n",
       "            (lin2): Linear(in_features=200, out_features=200, bias=False)\n",
       "            (bn2): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (1): Linear(in_features=200, out_features=1, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "let's first train as usual"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer_3 = Trainer(model, objective=\"binary\", metrics=[Accuracy])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████████████████████████████████████████████████████| 172/172 [00:03<00:00, 54.23it/s, loss=0.382, metrics={'acc': 0.8239}]\n",
      "valid: 100%|██████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 84.72it/s, loss=0.331, metrics={'acc': 0.8526}]\n",
      "epoch 2: 100%|███████████████████████████████████████████████████████████| 172/172 [00:03<00:00, 54.35it/s, loss=0.33, metrics={'acc': 0.8465}]\n",
      "valid: 100%|██████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 68.15it/s, loss=0.312, metrics={'acc': 0.8604}]\n"
     ]
    }
   ],
   "source": [
    "trainer_3.fit(\n",
    "    X_wide=X_wide, X_tab=X_tab, target=target, val_split=0.1, n_epochs=2, batch_size=256\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer_3.save(path=\"models_dir\", save_state_dict=True, model_filename=\"model_3.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we are going to fine-tune the model components, and in the case of the `deeptabular` component, we will fine-tune the resnet-blocks and the linear layer but NOT the embeddings. \n",
    "\n",
    "For this, we need to access the model component's children: ``deeptabular`` $\\rightarrow$ ``tab_resnet`` $\\rightarrow$ ``dense_resnet`` $\\rightarrow$ ``blocks``"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "wide_3 = Wide(input_dim=np.unique(X_wide).shape[0], pred_dim=1)\n",
    "tab_resnet_3 = TabResnet(\n",
    "    column_idx=tab_preprocessor.column_idx,\n",
    "    cat_embed_input=tab_preprocessor.cat_embed_input,\n",
    "    cat_embed_dropout=0.1,\n",
    "    continuous_cols=continuous_cols,\n",
    "    blocks_dims=[200, 200, 200],\n",
    ")\n",
    "model_3 = WideDeep(wide=wide, deeptabular=tab_resnet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_3.load_state_dict(torch.load(\"models_dir/model_3.pt\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "WideDeep(\n",
       "  (wide): Wide(\n",
       "    (wide_linear): Embedding(809, 1, padding_idx=0)\n",
       "  )\n",
       "  (deeptabular): Sequential(\n",
       "    (0): TabResnet(\n",
       "      (cat_embed): DiffSizeCatEmbeddings(\n",
       "        (embed_layers): ModuleDict(\n",
       "          (emb_layer_workclass): Embedding(10, 5, padding_idx=0)\n",
       "          (emb_layer_education): Embedding(17, 8, padding_idx=0)\n",
       "          (emb_layer_marital_status): Embedding(8, 5, padding_idx=0)\n",
       "          (emb_layer_occupation): Embedding(16, 7, padding_idx=0)\n",
       "          (emb_layer_relationship): Embedding(7, 4, padding_idx=0)\n",
       "          (emb_layer_race): Embedding(6, 4, padding_idx=0)\n",
       "          (emb_layer_gender): Embedding(3, 2, padding_idx=0)\n",
       "          (emb_layer_capital_gain): Embedding(124, 24, padding_idx=0)\n",
       "          (emb_layer_capital_loss): Embedding(100, 21, padding_idx=0)\n",
       "          (emb_layer_native_country): Embedding(43, 13, padding_idx=0)\n",
       "        )\n",
       "        (embedding_dropout): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "      (cont_norm): Identity()\n",
       "      (encoder): DenseResnet(\n",
       "        (dense_resnet): Sequential(\n",
       "          (lin_inp): Linear(in_features=95, out_features=200, bias=False)\n",
       "          (bn_inp): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (block_0): BasicBlock(\n",
       "            (lin1): Linear(in_features=200, out_features=200, bias=False)\n",
       "            (bn1): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (leaky_relu): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "            (dp): Dropout(p=0.1, inplace=False)\n",
       "            (lin2): Linear(in_features=200, out_features=200, bias=False)\n",
       "            (bn2): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "          (block_1): BasicBlock(\n",
       "            (lin1): Linear(in_features=200, out_features=200, bias=False)\n",
       "            (bn1): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (leaky_relu): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "            (dp): Dropout(p=0.1, inplace=False)\n",
       "            (lin2): Linear(in_features=200, out_features=200, bias=False)\n",
       "            (bn2): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (1): Linear(in_features=200, out_features=1, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "tab_lin_layer = list(model_3.deeptabular.children())[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Linear(in_features=200, out_features=1, bias=True)"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tab_lin_layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "tab_deep_layers = []\n",
    "for n1, c1 in model_3.deeptabular.named_children():\n",
    "    # \"0\" is the model component; \"1\" is always the prediction layer added by the `WideDeep` class\n",
    "    if n1 == \"0\":\n",
    "        for n2, c2 in c1.named_children():\n",
    "            if n2 == \"encoder\":  # the DenseResnet encoder of TabResnet\n",
    "                for _, c3 in c2.named_children():  # c3 is `dense_resnet`\n",
    "                    for n4, c4 in c3.named_children():\n",
    "                        if \"block\" in n4:  # keep only the residual blocks\n",
    "                            tab_deep_layers.append((n4, c4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('block_0',\n",
       "  BasicBlock(\n",
       "    (lin1): Linear(in_features=200, out_features=200, bias=False)\n",
       "    (bn1): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (leaky_relu): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "    (dp): Dropout(p=0.1, inplace=False)\n",
       "    (lin2): Linear(in_features=200, out_features=200, bias=False)\n",
       "    (bn2): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  )),\n",
       " ('block_1',\n",
       "  BasicBlock(\n",
       "    (lin1): Linear(in_features=200, out_features=200, bias=False)\n",
       "    (bn1): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (leaky_relu): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "    (dp): Dropout(p=0.1, inplace=False)\n",
       "    (lin2): Linear(in_features=200, out_features=200, bias=False)\n",
       "    (bn2): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  ))]"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tab_deep_layers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now remember, we need to pass ONLY the layers (the names above were included just for clarity), and in the order in which they should be warmed up. Therefore:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "tab_deep_layers = [el[1] for el in tab_deep_layers][::-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "tab_layers = [tab_lin_layer] + tab_deep_layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Linear(in_features=200, out_features=1, bias=True),\n",
       " BasicBlock(\n",
       "   (lin1): Linear(in_features=200, out_features=200, bias=False)\n",
       "   (bn1): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "   (leaky_relu): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "   (dp): Dropout(p=0.1, inplace=False)\n",
       "   (lin2): Linear(in_features=200, out_features=200, bias=False)\n",
       "   (bn2): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       " ),\n",
       " BasicBlock(\n",
       "   (lin1): Linear(in_features=200, out_features=200, bias=False)\n",
       "   (bn1): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "   (leaky_relu): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "   (dp): Dropout(p=0.1, inplace=False)\n",
       "   (lin2): Linear(in_features=200, out_features=200, bias=False)\n",
       "   (bn2): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       " )]"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tab_layers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And now simply"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer_4 = Trainer(model_3, objective=\"binary\", metrics=[Accuracy])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training wide for 2 epochs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████████████████████████████████████████████████████| 172/172 [00:01<00:00, 95.17it/s, loss=0.504, metrics={'acc': 0.7523}]\n",
      "epoch 2: 100%|███████████████████████████████████████████████████████████| 172/172 [00:01<00:00, 99.83it/s, loss=0.384, metrics={'acc': 0.789}]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training deeptabular, layer 1 of 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████████████████████████████████████████████████████| 172/172 [00:02<00:00, 72.31it/s, loss=0.317, metrics={'acc': 0.8098}]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training deeptabular, layer 2 of 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████████████████████████████████████████████████████| 172/172 [00:02<00:00, 65.97it/s, loss=0.312, metrics={'acc': 0.8214}]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training deeptabular, layer 3 of 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████████████████████████████████████████████████████| 172/172 [00:02<00:00, 63.92it/s, loss=0.306, metrics={'acc': 0.8284}]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fine-tuning (or warmup) of individual components completed. Training the whole model for 2 epochs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████████████████████████████████████████████████████| 172/172 [00:03<00:00, 57.26it/s, loss=0.292, metrics={'acc': 0.8664}]\n",
      "valid: 100%|██████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 84.56it/s, loss=0.292, metrics={'acc': 0.8696}]\n",
      "epoch 2: 100%|██████████████████████████████████████████████████████████| 172/172 [00:03<00:00, 53.61it/s, loss=0.282, metrics={'acc': 0.8693}]\n",
      "valid: 100%|██████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 80.59it/s, loss=0.289, metrics={'acc': 0.8719}]\n"
     ]
    }
   ],
   "source": [
    "trainer_4.fit(\n",
    "    X_wide=X_wide,\n",
    "    X_tab=X_tab,\n",
    "    target=target,\n",
    "    val_split=0.1,\n",
    "    finetune=True,\n",
    "    finetune_epochs=2,\n",
    "    deeptabular_gradual=True,\n",
    "    deeptabular_layers=tab_layers,\n",
    "    deeptabular_max_lr=0.01,\n",
    "    n_epochs=2,\n",
    "    batch_size=256,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, there is one more use case I would like to consider: the case where we train only one component and just want to fine-tune it and stop the training afterwards, since there is no joint training. This is as simple as:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "tab_mlp = TabMlp(\n",
    "    column_idx=tab_preprocessor.column_idx,\n",
    "    cat_embed_input=tab_preprocessor.cat_embed_input,\n",
    "    cat_embed_dropout=0.1,\n",
    "    continuous_cols=continuous_cols,\n",
    "    mlp_hidden_dims=[64, 32],\n",
    "    mlp_dropout=0.2,\n",
    "    mlp_activation=\"leaky_relu\",\n",
    ")\n",
    "model = WideDeep(deeptabular=tab_mlp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer_5 = Trainer(\n",
    "    model,\n",
    "    objective=\"binary\",\n",
    "    optimizers=torch.optim.Adam(model.parameters(), lr=0.01),\n",
    "    metrics=[Accuracy],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████████████████████████████████████████████████████| 172/172 [00:02<00:00, 73.69it/s, loss=0.365, metrics={'acc': 0.8331}]\n",
      "valid: 100%|██████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 92.56it/s, loss=0.299, metrics={'acc': 0.8673}]\n"
     ]
    }
   ],
   "source": [
    "trainer_5.fit(\n",
    "    X_wide=X_wide, X_tab=X_tab, target=target, val_split=0.1, n_epochs=1, batch_size=256\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer_5.save(path=\"models_dir\", save_state_dict=True, model_filename=\"model_5.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "tab_mlp_5 = TabMlp(\n",
    "    column_idx=tab_preprocessor.column_idx,\n",
    "    cat_embed_input=tab_preprocessor.cat_embed_input,\n",
    "    cat_embed_dropout=0.1,\n",
    "    continuous_cols=continuous_cols,\n",
    "    mlp_hidden_dims=[64, 32],\n",
    "    mlp_dropout=0.2,\n",
    "    mlp_activation=\"leaky_relu\",\n",
    ")\n",
    "model_5 = WideDeep(deeptabular=tab_mlp_5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_5.load_state_dict(torch.load(\"models_dir/model_5.pt\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "...time goes by..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer_6 = Trainer(\n",
    "    model_5,\n",
    "    objective=\"binary\",\n",
    "    optimizers=torch.optim.Adam(model_5.parameters(), lr=0.01),\n",
    "    metrics=[Accuracy],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training deeptabular for 2 epochs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████████████████████████████████████████████████████| 172/172 [00:02<00:00, 73.86it/s, loss=0.298, metrics={'acc': 0.8652}]\n",
      "epoch 2: 100%|██████████████████████████████████████████████████████████| 172/172 [00:02<00:00, 75.45it/s, loss=0.286, metrics={'acc': 0.8669}]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fine-tuning (or warmup) of individual components completed. Training the whole model for 1 epochs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████████████████████████████████████████████████████| 172/172 [00:02<00:00, 76.29it/s, loss=0.282, metrics={'acc': 0.8698}]\n",
      "valid: 100%|██████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 84.93it/s, loss=0.281, metrics={'acc': 0.8749}]\n"
     ]
    }
   ],
   "source": [
    "trainer_6.fit(\n",
    "    X_wide=X_wide,\n",
    "    X_tab=X_tab,\n",
    "    target=target,\n",
    "    val_split=0.1,\n",
    "    finetune=True,\n",
    "    finetune_epochs=2,\n",
    "    finetune_max_lr=0.01,\n",
    "    stop_after_finetuning=True,\n",
    "    batch_size=256,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "\n",
    "shutil.rmtree(\"models_dir/\")\n",
    "shutil.rmtree(\"model_weights/\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "3b99005fd577fa40f3cce433b2b92303885900e634b2b5344c07c59d06c8792d"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
